AVL Trees

AVL trees are self-balancing binary search trees in which, for every node, the heights of the left and right subtrees differ by at most 1. This invariant keeps the height of an $n$-node tree in $O(\log n)$, so search, insertion, and deletion all take $O(\log n)$ time, even after many updates. A standard AVL tree restores the invariant with rotations after each insertion or deletion, which makes it well suited to applications that need fast lookups and consistent operation times. This notebook builds a plain BST first and then rebalances it with the same rotations.
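To see why bounding the height difference bounds the height, consider the fewest nodes $N(h)$ an AVL tree of height $h$ can have. Its root has one subtree of height $h-1$ and, at worst, one of height $h-2$, so $N(h) = N(h-1) + N(h-2) + 1$, with $N(0) = 0$ and $N(1) = 1$ under this notebook's convention that a leaf has height 1. This grows like the Fibonacci numbers, so $h$ is $O(\log n)$. A small sketch that tabulates the bound; the helper names are hypothetical:

```python
# Hypothetical helpers (not part of the notebook) for the AVL height bound.

def min_avl_nodes(h: int) -> int:
    """Fewest nodes in an AVL tree of height h (a leaf has height 1, empty tree 0)."""
    if h <= 0:
        return 0
    if h == 1:
        return 1
    # N(h) = root + minimal subtree of height h-1 + minimal subtree of height h-2
    prev2, prev1 = 0, 1  # N(0), N(1)
    for _ in range(2, h + 1):
        prev2, prev1 = prev1, prev1 + prev2 + 1
    return prev1


def max_avl_height(n: int) -> int:
    """Largest height an AVL tree with n nodes can reach."""
    h = 0
    while min_avl_nodes(h + 1) <= n:
        h += 1
    return h


for n in (10, 1_000, 1_000_000):
    print(n, max_avl_height(n))  # heights 4, 14 and 28
```

Even a million keys fit in an AVL tree of height at most 28, whereas a plain BST built from sorted keys would have height one million.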

In [1]:
class Node:
    def __init__(self, value_ = 0):
        self.value = value_
        self.left = None
        self.right = None
        self.height = 1

    def __eq__(self, other: object) -> bool:
        # nodes are equal when they hold the same value
        if not isinstance(other, Node):
            return NotImplemented
        return self.value == other.value

Function to insert a new node

In [2]:
def insert(self, value_: int):
    # plain BST insertion: duplicates are ignored and no rotations are done here
    if value_ < self.value:
        if self.left is None:
            self.left = Node(value_)
        else:
            self.left.insert(value_)
    elif value_ > self.value:
        if self.right is None:
            self.right = Node(value_)
        else:
            self.right.insert(value_)
    # recompute this node's height on the way back up the recursion
    leftHeight = (0 if self.left is None else self.left.height)
    rightHeight = (0 if self.right is None else self.right.height)
    self.height = max(leftHeight, rightHeight) + 1
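Because `insert` updates heights but never rotates, the tree can drift out of balance. A minimal sketch of the worst case, using a hypothetical `PlainNode` class with the same insertion logic (a separate name so it does not clash with the notebook's `Node`): inserting keys in sorted order produces a chain whose height equals the number of keys.

```python
class PlainNode:
    """Hypothetical stand-in for Node: BST insert with height tracking, no rotations."""

    def __init__(self, value: int):
        self.value = value
        self.left = None
        self.right = None
        self.height = 1

    def insert(self, value: int) -> None:
        if value < self.value:
            if self.left is None:
                self.left = PlainNode(value)
            else:
                self.left.insert(value)
        elif value > self.value:
            if self.right is None:
                self.right = PlainNode(value)
            else:
                self.right.insert(value)
        # recompute this node's height on the way back up
        left_h = self.left.height if self.left else 0
        right_h = self.right.height if self.right else 0
        self.height = max(left_h, right_h) + 1


chain = PlainNode(1)
for key in range(2, 11):
    chain.insert(key)
print(chain.height)  # 10: every node has only a right child
```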

Function to traverse a tree

In [3]:
def traverse(self, indicator_="root"):
    # preorder: print this node (value and height), then its left and right subtrees
    print(f"{indicator_}: {self.value} {self.height}h")

    if self.left is not None:
        self.left.traverse("left")
    if self.right is not None:
        self.right.traverse("right")
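`traverse` prints nodes in preorder (node, left subtree, right subtree), which is handy for reading off the tree's shape. An inorder walk (left subtree, node, right subtree) is the usual way to check the BST property, because it visits the keys in ascending order. A sketch with hypothetical names (`TreeNode`, `inorder`); `inorder` only relies on `value`, `left` and `right` attributes, so it also works on the notebook's `Node`:

```python
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class TreeNode:
    """Hypothetical minimal node with the same attribute names as Node."""
    value: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield keys in the order: left subtree, node, right subtree."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.value
    yield from inorder(node.right)


# The unbalanced example tree built later in this notebook
tree = TreeNode(50,
                TreeNode(10),
                TreeNode(70, TreeNode(60), TreeNode(80, TreeNode(75), TreeNode(90))))
keys = list(inorder(tree))
print(keys)  # [10, 50, 60, 70, 75, 80, 90]
assert keys == sorted(keys), "BST property violated"
```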

Binding the methods to the class

In [4]:
Node.insert = insert
Node.traverse = traverse

Function to update the height of a node

In [5]:
def updateHeight(node: Node):
    leftHeight = (0 if node.left is None else node.left.height)
    rightHeight = (0 if node.right is None else node.right.height)
    node.height = max(leftHeight, rightHeight) + 1
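With a height on every node, the balance factor of a node is its left subtree's height minus its right subtree's height, and the AVL condition is that this value lies in $\{-1, 0, 1\}$ everywhere. A sketch of a checker with the hypothetical name `check_avl`: it recomputes heights instead of trusting the stored ones and checks only the height condition, not the key ordering.

```python
from types import SimpleNamespace
from typing import Optional


def check_avl(node) -> Optional[int]:
    """Hypothetical checker: subtree height if every |balance factor| <= 1, else None."""
    if node is None:
        return 0
    left_h = check_avl(node.left)
    right_h = check_avl(node.right)
    if left_h is None or right_h is None:
        return None  # a violation was already found deeper in the tree
    if abs(left_h - right_h) > 1:
        return None  # this node is unbalanced
    return max(left_h, right_h) + 1


def leaf(v):
    # hypothetical lightweight node with value/left/right attributes
    return SimpleNamespace(value=v, left=None, right=None)


balanced = SimpleNamespace(value=2, left=leaf(1), right=leaf(3))
skewed = SimpleNamespace(value=1, left=None,
                         right=SimpleNamespace(value=2, left=None, right=leaf(3)))
print(check_avl(balanced))  # 2
print(check_avl(skewed))    # None
```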

Rotations in AVL trees

graph RL;
    A1(A)
    A2(A)
    B1(B)
    B2(B)
    X1(X)
    X2(X)
    Y1(Y)
    Y2(Y)
    Z1(Z)
    Z2(Z)
    
    subgraph Tree1
        A1-->X1;
        A1-->B1;
        B1-->Y1;
        B1-->Z1;
    end

    subgraph Tree2
        B2-->A2;
        B2-->Z2;
        A2-->X2;
        A2-->Y2;
    end
    Tree1 --Left rotation--> Tree2
    Tree2 --Right rotation--> Tree1

Function to left-rotate a node

In [8]:
def leftRotate(rootNode: Node) -> Node:
    rightNode = rootNode.right
    rightLeftNode = rightNode.left
    rightNode.left = rootNode
    rootNode.right = rightLeftNode

    # updating the new heights
    updateHeight(rootNode)
    updateHeight(rightNode)

    return rightNode

Function to right-rotate a node

In [9]:
def rightRotate(rootNode: Node) -> Node:
    leftNode = rootNode.left
    leftRightNode = leftNode.right  # the left child's right subtree moves under rootNode
    leftNode.right = rootNode
    rootNode.left = leftRightNode

    # updating the new heights
    updateHeight(rootNode)
    updateHeight(leftNode)

    return leftNode

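Note: Both rotations run in $O(1)$ time, since each reassigns a constant number of child pointers and recomputes two heights. They also preserve the BST ordering: in both trees of the diagram above, an inorder walk visits X, A, Y, B, Z.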

Algorithm

  1. Base Case: If the current node (rootNode) is None, return None.
  2. Recursively Balance Left Subtree: Call balanceTree on rootNode.left and update rootNode.left with the returned balanced subtree.
  3. Recursively Balance Right Subtree: Call balanceTree on rootNode.right and update rootNode.right with the returned balanced subtree.
  4. Update the Height of the Current Node:
    • Calculate the height of the left child: leftHeight = 0 if rootNode.left is None else rootNode.left.height
    • Calculate the height of the right child: rightHeight = 0 if rootNode.right is None else rootNode.right.height
    • Update current node’s height: rootNode.height = 1 + max(leftHeight, rightHeight)
  5. Calculate the Balance Factor of Current Node: balanceFactor = leftHeight - rightHeight
  6. Check if Current Node is Unbalanced:
    • If abs(balanceFactor) < 2, the node is balanced; no rotations needed. Return rootNode.
    • Else the node is unbalanced. Determine whether the imbalance is left-heavy or right-heavy:
      • Left-heavy: balanceFactor > 1
      • Right-heavy: balanceFactor < -1
  7. For Left-heavy Cases (balanceFactor > 1):
    • Compute left child’s balance factor:
    leftChildBalance = (height of left child's left subtree) - (height of left child's right subtree)
    
    • If leftChildBalance >= 0 (Left-Left Case):
      • Perform a right rotation on rootNode.
      • Return the new subtree root after rotation.
    • Else (Left-Right Case):
      • Perform a left rotation on rootNode.left.
      • Perform a right rotation on rootNode.
      • Return the new subtree root after rotations.
  8. For Right-heavy Cases (balanceFactor < -1):
    • Compute right child’s balance factor:
    rightChildBalance = (height of right child's left subtree) - (height of right child's right subtree)
    
    • If rightChildBalance <= 0 (Right-Right Case):
      • Perform a left rotation on rootNode.
      • Return the new subtree root after rotation.
    • Else (Right-Left Case):
      • Perform a right rotation on rootNode.right.
      • Perform a left rotation on rootNode.
      • Return the new subtree root after rotations.
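Steps 5 to 8 can be exercised on the four smallest unbalanced trees: the keys 1, 2, 3 arranged as a chain. Each chain has a root balance factor of ±2, and every case should end with the middle key, 2, at the root and a height of 2. A self-contained sketch with hypothetical names (`CNode`, `rebalance_case`, and small height helpers), applying the steps to a single node whose subtrees are already balanced:

```python
def height_of(node):
    return 0 if node is None else node.height


def balance_of(node):
    """Balance factor = height(left subtree) - height(right subtree)."""
    return height_of(node.left) - height_of(node.right)


def refresh_height(node):
    node.height = 1 + max(height_of(node.left), height_of(node.right))


class CNode:
    """Hypothetical minimal node with the same fields as the notebook's Node."""

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right
        refresh_height(self)


def rot_left(a):
    b = a.right
    a.right, b.left = b.left, a
    refresh_height(a)
    refresh_height(b)
    return b


def rot_right(b):
    a = b.left
    b.left, a.right = a.right, b
    refresh_height(b)
    refresh_height(a)
    return a


def rebalance_case(node):
    """Hypothetical helper: apply steps 5-8 to one node and report which case fired."""
    balance = balance_of(node)
    if balance > 1:                       # left-heavy
        if balance_of(node.left) >= 0:
            return "LL", rot_right(node)
        node.left = rot_left(node.left)
        return "LR", rot_right(node)
    if balance < -1:                      # right-heavy
        if balance_of(node.right) <= 0:
            return "RR", rot_left(node)
        node.right = rot_right(node.right)
        return "RL", rot_left(node)
    return "balanced", node


# Keys 1, 2, 3 arranged as the four possible chains
chains = {
    "LL": CNode(3, left=CNode(2, left=CNode(1))),
    "LR": CNode(3, left=CNode(1, right=CNode(2))),
    "RR": CNode(1, right=CNode(2, right=CNode(3))),
    "RL": CNode(1, right=CNode(3, left=CNode(2))),
}
results = {}
for label, chain_root in chains.items():
    case, new_root = rebalance_case(chain_root)
    results[label] = (case, new_root.value, new_root.height)
print(results)  # each entry is (label, 2, 2)
```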

Function to balance a BST

In [10]:
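def nodeHeight(node: Node) -> int:
    # height of a possibly empty subtree (an empty subtree has height 0)
    return 0 if node is None else node.height


def balanceTree(rootNode: Node) -> Node:
    # 1. base case: an empty subtree needs no balancing
    if rootNode is None:
        return None
    # 2-3. balance both subtrees first (bottom-up)
    rootNode.left = balanceTree(rootNode.left)
    rootNode.right = balanceTree(rootNode.right)

    # 4. the children may have been rotated, so refresh this node's height
    updateHeight(rootNode)

    # 5. calculating the balance factor
    balanceFactor = nodeHeight(rootNode.left) - nodeHeight(rootNode.right)

    # 7. left-heavy
    if balanceFactor > 1:
        leftChild = rootNode.left
        if nodeHeight(leftChild.left) < nodeHeight(leftChild.right):
            # Left-Right case: left-rotate the child to get a Left-Left case
            rootNode.left = leftRotate(leftChild)
        return rightRotate(rootNode)
    # 8. right-heavy
    if balanceFactor < -1:
        rightChild = rootNode.right
        if nodeHeight(rightChild.left) > nodeHeight(rightChild.right):
            # Right-Left case: right-rotate the child to get a Right-Right case
            rootNode.right = rightRotate(rightChild)
        return leftRotate(rootNode)
    # 6. |balanceFactor| < 2: already balanced
    return rootNode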
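`balanceTree` repairs the tree after the fact in one bottom-up pass, giving each node at most one single or double rotation. That is enough for the driver example that follows, where only the root is off by 2, but a single pass does not guarantee an AVL tree for an arbitrarily skewed BST: rotating a node whose subtrees differ in height by more than 2 can leave it unbalanced. The textbook AVL approach instead rebalances during insertion, applying the same four cases to each node on the way back up the recursion. A self-contained sketch of that alternative, with hypothetical names (`AVLNode`, `avl_insert`) that are not part of this notebook:

```python
from typing import Optional


class AVLNode:
    """Hypothetical node type for this sketch (separate from the notebook's Node)."""

    def __init__(self, value: int):
        self.value = value
        self.left: Optional["AVLNode"] = None
        self.right: Optional["AVLNode"] = None
        self.height = 1


def _h(node: Optional[AVLNode]) -> int:
    return 0 if node is None else node.height


def _update(node: AVLNode) -> None:
    node.height = 1 + max(_h(node.left), _h(node.right))


def _rot_left(a: AVLNode) -> AVLNode:
    b = a.right
    a.right, b.left = b.left, a
    _update(a)
    _update(b)
    return b


def _rot_right(b: AVLNode) -> AVLNode:
    a = b.left
    b.left, a.right = a.right, b
    _update(b)
    _update(a)
    return a


def avl_insert(node: Optional[AVLNode], value: int) -> AVLNode:
    """Insert value (duplicates ignored) and return the rebalanced subtree root."""
    if node is None:
        return AVLNode(value)
    if value < node.value:
        node.left = avl_insert(node.left, value)
    elif value > node.value:
        node.right = avl_insert(node.right, value)
    else:
        return node  # duplicate key: tree unchanged
    _update(node)
    balance = _h(node.left) - _h(node.right)
    if balance > 1:                                     # left-heavy
        if _h(node.left.left) < _h(node.left.right):    # Left-Right case
            node.left = _rot_left(node.left)
        return _rot_right(node)
    if balance < -1:                                    # right-heavy
        if _h(node.right.right) < _h(node.right.left):  # Right-Left case
            node.right = _rot_right(node.right)
        return _rot_left(node)
    return node


def preorder_values(node: Optional[AVLNode]) -> list:
    if node is None:
        return []
    return [node.value] + preorder_values(node.left) + preorder_values(node.right)


# Sorted input: the worst case for a plain BST
sorted_root = None
for key in range(1, 8):
    sorted_root = avl_insert(sorted_root, key)
print(sorted_root.height, preorder_values(sorted_root))  # 3 [4, 2, 1, 3, 6, 5, 7]

# The notebook's keys: same shape as the balanced result shown below
example_root = None
for key in [50, 10, 70, 60, 80, 75, 90]:
    example_root = avl_insert(example_root, key)
print(preorder_values(example_root))  # [70, 50, 10, 60, 80, 75, 90]
```

Rebalancing during insertion keeps every intermediate tree balanced, so no separate repair pass is needed.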

Driver code

In [11]:
arr = [50, 10, 70, 60, 80, 75, 90]

root = Node(arr[0])
for i in range(1, len(arr)):
    root.insert(arr[i])
root.traverse()
root: 50 4h
left: 10 1h
right: 70 3h
left: 60 1h
right: 80 2h
left: 75 1h
right: 90 1h

Visualizing the unbalanced BST

graph TD
    50 --> 10;
    50 --> 70;
    70 --> 60;
    70 --> 80;
    80 --> 75;
    80 --> 90;
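In the printed heights, the root 50 has a left subtree of height 1 (node 10) and a right subtree of height 3 (node 70), so its balance factor is $1 - 3 = -2$, while every other node is within ±1. Its right child 70 has balance factor $1 - 2 = -1 \le 0$, so this is the Right-Right case, which a single left rotation at the root fixes. A small sketch that computes all balance factors, encoding the tree as hypothetical `(value, left, right)` tuples rather than `Node` objects:

```python
def balance_factors(tree, out):
    """Fill out[value] = height(left) - height(right); return the subtree height.

    Hypothetical helper: a tree is a (value, left, right) tuple, or None if empty.
    """
    if tree is None:
        return 0
    value, left, right = tree
    left_h = balance_factors(left, out)
    right_h = balance_factors(right, out)
    out[value] = left_h - right_h
    return max(left_h, right_h) + 1


# The unbalanced example tree as (value, left, right) tuples
example = (50, (10, None, None),
               (70, (60, None, None),
                    (80, (75, None, None), (90, None, None))))
factors = {}
example_height = balance_factors(example, factors)
print(example_height)  # 4
print(factors)         # {10: 0, 60: 0, 75: 0, 90: 0, 80: 0, 70: -1, 50: -2}
```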

Balancing the BST

In [13]:
root = balanceTree(root)
root.traverse()
root: 70 3h
left: 50 2h
left: 10 1h
right: 60 1h
right: 80 2h
left: 75 1h
right: 90 1h

Visualizing the balanced BST

graph TD
    70 --> 50;
    70 --> 80;
    50 --> 10;
    50 --> 60;
    80 --> 75;
    80 --> 90;